import json
import sys
import re
import torch
import torch.nn.functional as F
from sentence_transformers import util
from sentence_transformers import SentenceTransformer
from oai_trained_model import Model
sys.path.append("../utils")
from parse_code import preparing_code_for_parsing, get_apis_used
from retrieve_top_string import retrieve_top_similar_strings
from get_embedding import get_azure_llm_emb

# Final API-name segments that get_docpg_for_api never maps to a documentation
# node (root objects and generic names such as "context").
ignore_list = ['sync', '', 'Excel', 'ExcelScript', 'workbook', 'worksheets', 'context']
# Selects which cached SBERT embedding files are loaded
# ("..._sbert_choice{choice}.json") and the ".<choice>" key suffix used for
# doc_ent_mapping lookups of APIs that are not of the form "Class.member".
choice = 0
# Captures the bare method name before the first "(" of a signature such as
# "getRange(address)". Raw string: in a plain literal "\(" is an invalid escape
# sequence (DeprecationWarning, and SyntaxWarning on Python 3.12+).
pattern = r"(.+?)\("

def _load_json(path):
    """Load and return the JSON document stored at *path*, decoded as UTF-8."""
    # An explicit encoding makes decoding independent of the platform's locale
    # default (e.g. cp1252 on Windows), which can mis-decode UTF-8 text.
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


# NOTE(review): these cache paths mix absolute, machine-specific Windows paths
# with paths relative to the current working directory; they only resolve on
# the original author's machine and when run from the expected directory.

# API name (or "<api>.<choice>") -> key into the doc embedding tables below.
doc_ent_mapping = _load_json(r"C:\t-avikdutta\RAR paper\code\src\cache_data\doc_str_emb_mapping.json")
# Doc-entity key -> SBERT embedding (used by get_sim_values).
api_emb_mapping = _load_json(f"../src/cache_data/doc_embeddings_sbert_choice{choice}.json")
# Doc-entity key -> non-SBERT embedding used by oai_dense_model_sim_score
# (presumably OpenAI embeddings, per the "oaiemb" name — verify).
api_oaiemb_mapping = _load_json("../src/cache_data/doc_embeddings.json")
# Test-query text -> SBERT embedding.
query_emb_mapping = _load_json(r"C:\t-avikdutta\RAR paper\code\src\cache_data\InstructExcelCF_query_emb_sbert_choice{}.json".format(choice))
# Example-corpus query text -> SBERT embedding.
example_query_emb_mapping = _load_json(r"C:\t-avikdutta\RAR paper\code\src\cache_data\example_corpus_query_emb_sbert_choice{}.json".format(choice))
# Example-corpus query text -> non-SBERT embedding.
# NOTE(review): oai_dense_model_sim_score looks its query up here — confirm the
# queries it is called with are present in this example-corpus file.
query_oaiemb_mapping = _load_json(r"C:\t-avikdutta\RAR paper\code\src\cache_data\example_corpus_query_emb.json")

def get_docpg_for_api(api, pg_dict_nodes):
    """Find the uid of the documentation node that best matches a raw API reference.

    The last dot-separated segment of *api* is compared against each node's
    own short name and against the names of its properties, methods (signature
    stripped to the bare name with ``pattern``) and fields. If several nodes
    match, the candidate at the first index returned by
    ``retrieve_top_similar_strings(api, uids)`` is chosen.

    Args:
        api: Dotted API reference extracted from code.
        pg_dict_nodes: Mapping of node uid -> documentation node. Each node is
            read for ``name``, ``properties``, ``methods`` and ``fields``.

    Returns:
        The selected node uid. Returns None when *api* is empty, when its last
        segment is in ``ignore_list``, or when no node matches.
    """
    if not api:
        return None
    api_to_check = api.split('.')[-1]
    if api_to_check in ignore_list:
        return None

    selection_list = []
    for uid, node in pg_dict_nodes.items():
        # Only membership is tested, so collect the candidate names in a set.
        candidate_names = {node.name.split('.')[-1]}
        if node.properties:
            candidate_names.update(prop['name'] for prop in node.properties)
        if node.methods:
            for method in node.methods:
                match = re.search(pattern, method['name'])
                # Use the raw name when there is no "(...)" signature to strip,
                # instead of crashing on a None match.
                candidate_names.add(match.group(1) if match else method['name'])
        if node.fields:
            candidate_names.update(field['name'] for field in node.fields)

        if api_to_check in candidate_names:
            selection_list.append(uid)

    if not selection_list:
        return None
    index = retrieve_top_similar_strings(api, selection_list)
    return selection_list[index[0]]
    
def extract_raw_apis_from_code(code: str) -> list | None:
    """Extract the API references used in a code snippet.

    Args:
        code: Source code to analyse.

    Returns:
        The result of ``get_apis_used`` on the preprocessed code. Returns None
        if ``preparing_code_for_parsing`` yields None or if ``get_apis_used``
        raises an exception.
    """
    proc_code = preparing_code_for_parsing(code)
    if proc_code is None:
        return None
    try:
        extracted_apis = get_apis_used(proc_code)
    except Exception:
        # Best effort: a snippet that fails extraction is reported as None.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
    return extracted_apis

def map_rawapis_to_docent(extracted_apis, pg_node_dict) -> list:
    """Translate raw API references into documentation-entity names.

    get_docpg_for_api chooses a documentation node for each reference. A
    reference is dropped in two cases: when no node matches, or when its last
    segment equals the chosen node's own short name (it names the node
    itself, not a member).

    Enum nodes map to the node name. Every other node maps to
    "<node name>.<last segment of the reference>".

    Args:
        extracted_apis: Iterable of dotted API strings.
        pg_node_dict: Mapping of node uid -> documentation node.

    Returns:
        Entity names in input order; duplicates are kept.
    """
    results = []
    for raw_api in extracted_apis:
        uid = get_docpg_for_api(raw_api, pg_node_dict)
        if uid is None:
            continue
        doc_node = pg_node_dict[uid]
        member = raw_api.split('.')[-1]
        if member == doc_node.name.split('.')[-1]:
            continue
        if doc_node.type == 'enum':
            results.append(doc_node.name)
        else:
            results.append(doc_node.name + "." + member)
    return results

def extract_good_apis(code: str, pg_node_dict):
    """Return the unique documentation-entity names referenced by *code*.

    Args:
        code: Source code snippet.
        pg_node_dict: Mapping of node uid -> documentation node.

    Returns:
        A deduplicated list of entity names in first-seen order, or None if
        extract_raw_apis_from_code returned None.
    """
    extracted_apis = extract_raw_apis_from_code(code)
    if extracted_apis is None:
        return None
    good_apis = map_rawapis_to_docent(extracted_apis, pg_node_dict)
    # dict.fromkeys dedupes while keeping first-occurrence order. The result is
    # reproducible across runs; list(set(...)) order depends on string hash
    # randomisation.
    return list(dict.fromkeys(good_apis))

def get_api_corpus(pg_node_dict) -> list:
    """Build the list of API names that retrieval ranks against.

    Enum nodes contribute their own name. Every other node contributes
    "<node name>.<property name>" for each property and
    "<node name>.<method name>" for each method (signature stripped with
    ``pattern``). Fields are not included.

    Args:
        pg_node_dict: Mapping of node uid -> documentation node.

    Returns:
        API names in node order. Methods whose bare names coincide (e.g.
        overloads) produce repeated entries.
    """
    api_corpus = []
    for node in pg_node_dict.values():
        if node.type == 'enum':
            api_corpus.append(node.name)
            continue
        if node.properties:
            # extend() grows the list in place; the former `list = list + [...]`
            # re-copied the whole corpus for every node (quadratic).
            api_corpus.extend(node.name + "." + prop['name'] for prop in node.properties)
        if node.methods:
            for method in node.methods:
                match = re.search(pattern, method['name'])
                # Use the raw name when there is no "(...)" signature to strip.
                bare_name = match.group(1) if match else method['name']
                api_corpus.append(node.name + "." + bare_name)
    return api_corpus

def get_api_corpus_and_description(pg_node_dict):
    """Return the API corpus for *pg_node_dict*.

    This function is identical in behaviour to get_api_corpus and delegates to
    it, so the two cannot drift apart.

    NOTE(review): despite the name, no descriptions are returned, only the API
    names exactly as get_api_corpus produces them.
    """
    return get_api_corpus(pg_node_dict)

def get_sim_values(test_query, api_corpus, test_or_example: str = "test"):
    """Cosine similarity between a cached query embedding and each API's cached embedding.

    Args:
        test_query: Query text. Its precomputed SBERT embedding is read from
            example_query_emb_mapping when *test_or_example* is "example", and
            from query_emb_mapping otherwise.
        api_corpus: API names. A name with exactly two dot-separated parts is
            looked up in doc_ent_mapping as-is; any other name is looked up
            with a ".<choice>" suffix.
        test_or_example: "example" selects the example-query cache; any other
            value selects the test-query cache.

    Returns:
        Row 0 of the query-vs-API cosine-similarity matrix. This is one score
        per api_corpus entry when the cached query embedding is a single
        vector.
    """
    def doc_key(api):
        if len(api.split('.')) == 2:
            return doc_ent_mapping[api]
        return doc_ent_mapping[api + f".{choice}"]

    doc_api_embs = torch.tensor([api_emb_mapping[doc_key(api)] for api in api_corpus])
    source = example_query_emb_mapping if test_or_example == 'example' else query_emb_mapping
    query_emb = torch.tensor(source[test_query])
    return util.pytorch_cos_sim(query_emb, doc_api_embs)[0]

# oai_model = Model()
# oai_model.load_state_dict(torch.load('./oai_dense_trained.pth'))
# oai_model.eval()

# def get_sim_values_for_sbert(test_query, api_corpus):
#     global choice
#     embeddings = torch.tensor([api_emb_mapping[doc_ent_mapping[api]] if len(api.split('.')) == 2 else api_emb_mapping[doc_ent_mapping[api + f".{choice}"]] for api in api_corpus])
#     target_embedding = torch.tensor(query_emb_mapping[test_query])
#     # embeddings = model.encode(str_list, convert_to_tensor=True)
#     cosine_scores = util.cos_sim(embeddings, target_embedding).tolist()
#     cosine_scores = torch.tensor([score[0] for score in cosine_scores])
#     return cosine_scores

def retrieve_topn_apis(similarities, api_corpus, topn: int):
    """Return the *topn* APIs with the highest similarity, best first.

    Args:
        similarities: 1-D tensor of scores aligned index-for-index with *api_corpus*.
        api_corpus: API names.
        topn: Number of APIs to keep. Fewer are returned if the corpus is smaller.

    Returns:
        API names ordered by descending similarity.
    """
    ranking = torch.argsort(similarities, descending=True)
    best = ranking[:topn]
    return [api_corpus[position] for position in best.tolist()]

def retrieve_simthresh_apis(similarities, api_corpus, sim_thresh: float):
    """Return every API whose similarity is at least *sim_thresh*, in corpus order.

    Args:
        similarities: 1-D tensor of scores aligned index-for-index with *api_corpus*.
        api_corpus: API names.
        sim_thresh: Inclusive lower bound on the similarity.
    """
    scores = similarities.tolist()
    return [api_corpus[i] for i, score in enumerate(scores) if score >= sim_thresh]

def retrieve_bottomn_apis(similarities, api_corpus, bottomn: int):
    """Return the *bottomn* APIs with the lowest similarity, lowest first.

    Args:
        similarities: 1-D tensor of scores aligned index-for-index with *api_corpus*.
        api_corpus: API names.
        bottomn: Number of APIs to keep. Fewer are returned if the corpus is smaller.
    """
    ranking = torch.argsort(similarities, descending=False)
    worst = ranking[:bottomn].tolist()
    return [api_corpus[position] for position in worst]

def calc_recall(ret_list, good_list):
    """Return the fraction of the distinct relevant APIs that were retrieved.

    Both lists are compared as sets, so duplicate entries do not affect the
    score. The former denominator counted duplicates in *good_list* while the
    numerator did not.

    Args:
        ret_list: Retrieved API names.
        good_list: Relevant (ground-truth) API names.

    Returns:
        ``len(set(ret_list) & set(good_list)) / len(set(good_list))``.

    Raises:
        ZeroDivisionError: If *good_list* is empty.
    """
    good_set = set(good_list)
    return len(set(ret_list) & good_set) / len(good_set)

def calc_precision(ret_list, good_list):
    """Return the fraction of the distinct retrieved APIs that are relevant.

    Both lists are compared as sets. Repeated retrievals (api_corpus can
    contain the same name more than once) previously inflated the denominator
    only, so precision could never reach 1.

    Args:
        ret_list: Retrieved API names.
        good_list: Relevant (ground-truth) API names.

    Returns:
        ``len(set(ret_list) & set(good_list)) / len(set(ret_list))``.

    Raises:
        ZeroDivisionError: If *ret_list* is empty.
    """
    ret_set = set(ret_list)
    return len(ret_set & set(good_list)) / len(ret_set)

def calc_jaccard(ret_list, good_list):
    """Return the Jaccard similarity of the two lists viewed as sets.

    The value is the size of the intersection divided by the size of the
    union. Raises ZeroDivisionError when both lists are empty.
    """
    retrieved, relevant = set(ret_list), set(good_list)
    return len(retrieved & relevant) / len(retrieved | relevant)

def oai_dense_model_sim_score(test_query, api_corpus, model=None):
    """Score each API against a query using the trained dense projection model.

    The cached embedding of the query (query_oaiemb_mapping) and of each API
    (api_oaiemb_mapping, keyed via doc_ent_mapping) are each wrapped in a
    batch of one and passed through *model*. The cosine similarity of the two
    outputs is recorded for every API.

    Args:
        test_query: Query text; must be a key of query_oaiemb_mapping.
        api_corpus: API names. A name with exactly two dot-separated parts is
            looked up in doc_ent_mapping as-is; any other name is looked up
            with a ".<choice>" suffix.
        model: Callable projection model, e.g. a loaded ``Model``. Defaults to
            a module-level ``oai_model`` if one has been defined.

    Returns:
        1-D float tensor with one score per entry of *api_corpus*. The tensor
        is empty for an empty corpus.

    Raises:
        RuntimeError: If *model* is omitted, the corpus is non-empty and no
            module-level ``oai_model`` exists. Previously this surfaced as an
            opaque NameError.
    """
    apis = list(api_corpus)
    if not apis:
        return torch.tensor([])

    if model is None:
        model = globals().get("oai_model")
        if model is None:
            raise RuntimeError(
                "oai_dense_model_sim_score needs a model: pass model=... or define "
                "a module-level `oai_model` (e.g. a Model() loaded from "
                "'./oai_dense_trained.pth' and switched to eval mode)."
            )

    similarities = []
    # Inference only: no_grad skips autograd bookkeeping without changing the
    # outputs. The query projection does not depend on the API, so it is
    # computed once rather than on every iteration.
    with torch.no_grad():
        query_out = model(torch.tensor([query_oaiemb_mapping[test_query]]))
        for api in apis:
            if len(api.split('.')) == 2:
                doc_key = doc_ent_mapping[api]
            else:
                doc_key = doc_ent_mapping[api + f".{choice}"]
            api_out = model(torch.tensor([api_oaiemb_mapping[doc_key]]))
            similarities.append(F.cosine_similarity(query_out, api_out).item())
    return torch.tensor(similarities)